import argparse
import os
import time
import copy
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from glob import glob
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from dtd import *
from albumentations.pytorch import ToTensorV2
import torchvision
import tempfile
from torch.cuda.amp import autocast, GradScaler  # need pytorch>1.6
from losses import DiceLoss, FocalLoss, SoftCrossEntropyLoss, LovaszLoss

# Raise Pillow's pixel-count limit (its decompression-bomb threshold) to a very
# large value so that opening big images is not blocked by that check.
Image.MAX_IMAGE_PIXELS = 1000000000000000


class TamperDataset(Dataset):
    """LMDB-backed dataset yielding JPEG-recompressed images with tamper masks.

    Each sample is read from the LMDB environment under the keys
    ``image-%09d`` / ``label-%09d``, converted to grayscale, re-encoded as JPEG
    one to three times with the qualities stored in the per-sample record, and
    returned together with the DCT coefficients read back by ``jpegio`` and the
    quantization-table tensor looked up for the final quality.

    ``lmdb``, ``pickle``, ``six``, ``cv2`` and ``jpegio`` are not imported
    explicitly in this file; presumably they arrive through ``from dtd import *``
    -- TODO confirm.
    """

    # minq is the JPEG compression ratio (quality); here it only selects which record file is loaded.
    def __init__(self, roots, mode, minq=95, qtb=90, max_readers=64):
        """Open the LMDB environment and load the lookup tables.

        Args:
            roots: Path of the LMDB environment. It is also embedded in the
                record-file path ``'pks/' + roots + '_<minq>.pk'``.
            mode: Stored on the instance; not read anywhere in this class.
            minq: Integer used to pick the record file and stored as ``self.minq``.
            qtb: Accepted but not used in this class.
            max_readers: Forwarded to ``lmdb.open``.
        """
        self.envs = lmdb.open(roots, max_readers=max_readers, readonly=True, lock=False, readahead=False, meminit=False)
        with self.envs.begin(write=False) as txn:
            # The dataset length is stored in the database under 'num-samples'.
            self.nSamples = int(txn.get('num-samples'.encode('utf-8')))
        self.max_nums = self.nSamples
        self.minq = minq
        self.mode = mode
        # NOTE: pickle.load executes arbitrary code from the file; only load trusted files.
        with open('qt_table.pk', 'rb') as fpk:
            pks = pickle.load(fpk)
        # Table values converted to LongTensor; __getitem__ looks them up by integer JPEG quality.
        self.pks = {}
        for k, v in pks.items():
            self.pks[k] = torch.LongTensor(v)
        # Per-sample records; the trailing entries of each record are JPEG qualities (see __getitem__).
        with open('pks/' + roots + '_%d.pk' % minq, 'rb') as f:
            self.record = pickle.load(f)
        # The two flip transforms are created but not applied in __getitem__.
        self.hflip = torchvision.transforms.RandomHorizontalFlip(p=1.0)
        self.vflip = torchvision.transforms.RandomVerticalFlip(p=1.0)
        self.totsr = ToTensorV2()
        # NOTE(review): the second mean is 0.455; the commonly quoted ImageNet value is 0.456 -- confirm intent.
        self.toctsr = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                      torchvision.transforms.Normalize(mean=(0.485, 0.455, 0.406),
                                                                                       std=(0.229, 0.224, 0.225))])

    def __len__(self):
        """Return the number of samples recorded in the LMDB environment."""
        return self.max_nums

    def __getitem__(self, index):
        """Load sample ``index`` and return its tensors.

        Returns:
            dict with keys
            ``'image'``: normalized RGB tensor of the (re-compressed) image,
            ``'label'``: binary mask as a long tensor,
            ``'rgb'``: absolute values of the first DCT coefficient array, clipped to [0, 20],
            ``'q'``: quantization-table tensor for the final JPEG quality,
            ``'i'``: the final JPEG quality as an int.
        """
        with self.envs.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key.encode('utf-8'))
            # Wrap the raw bytes in an in-memory file so PIL can open them.
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            im = Image.open(buf)
            lbl_key = 'label-%09d' % index
            lblbuf = txn.get(lbl_key.encode('utf-8'))
            # Decode the label with flag 0 and binarize: any non-zero pixel becomes 1.
            mask = (cv2.imdecode(np.frombuffer(lblbuf, dtype=np.uint8), 0) != 0).astype(np.uint8)
            H, W = mask.shape  # not used below
            record = self.record[index]
            # Number of extra compression passes applied before the final one (0, 1 or 2 are acted on).
            choicei = len(record) - 1
            q = int(record[-1])
            use_qtb = self.pks[q]
            if choicei > 1:
                q2 = int(record[-3])
                use_qtb2 = self.pks[q2]  # looked up but not returned
            if choicei > 0:
                q1 = int(record[-2])
                use_qtb1 = self.pks[q1]  # looked up but not returned
            mask = self.totsr(image=mask.copy())['image']
            with tempfile.NamedTemporaryFile(delete=True) as tmp:
                im = im.convert("L")
                # Successive JPEG encodes: q2 (if present), then q1 (if present), then q.
                # NOTE(review): every save goes to the same open handle with no explicit
                # seek/truncate in between -- confirm that the file jpegio reads below
                # contains only the final encode.
                if choicei > 1:
                    im.save(tmp, "JPEG", quality=q2)
                    im = Image.open(tmp)
                if choicei > 0:
                    im.save(tmp, "JPEG", quality=q1)
                    im = Image.open(tmp)
                im.save(tmp, "JPEG", quality=q)
                # Read the coefficients back from disk while the temp file still exists.
                jpg = jpegio.read(tmp.name)
                dct = jpg.coef_arrays[0].copy()
                im = im.convert('RGB')
            return {
                'image': self.toctsr(im),
                'label': mask.long(),
                'rgb': np.clip(np.abs(dct), 0, 20),
                'q': use_qtb,
                'i': q
            }


class AverageMeter(object):
    """Running statistics tracker: last value, weighted sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count


class IOUMetric:
    """Accumulate a confusion matrix and derive accuracy / IoU statistics from it.

    ``hist[t, p]`` counts pixels whose true label is ``t`` and predicted label is ``p``.
    """

    def __init__(self, num_classes=10):
        self.num_classes = num_classes
        self.hist = np.zeros((num_classes, num_classes))

    def _fast_hist(self, label_pred, label_true):
        """Confusion matrix of one flattened prediction/target pair.

        Pixels whose true label lies outside ``[0, num_classes)`` are ignored.
        """
        n = self.num_classes
        valid = (label_true >= 0) & (label_true < n)
        # Encode each (true, pred) pair as a single bin index: true * n + pred.
        flat_index = n * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(flat_index, minlength=n ** 2)
        return counts.reshape(n, n)

    def add_batch(self, predictions, gts):
        """Add every (prediction, ground-truth) pair of a batch to the running matrix."""
        for pred, gt in zip(predictions, gts):
            self.hist += self._fast_hist(pred.flatten(), gt.flatten())

    def evaluate(self):
        """Return ``(acc, acc_cls, iu, mean_iu, fwavacc)`` for everything accumulated so far."""
        hist = self.hist
        correct = np.diag(hist)
        total = hist.sum()
        per_true = hist.sum(axis=1)
        per_pred = hist.sum(axis=0)

        acc = correct.sum() / total
        # Classes absent from the ground truth give NaN and are skipped by nanmean.
        acc_cls = np.nanmean(correct / per_true)
        iu = correct / (per_true + per_pred - correct)
        mean_iu = np.nanmean(iu)
        freq = per_true / total
        present = freq > 0
        fwavacc = (freq[present] * iu[present]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc


def get_logger(filename, verbosity=1, name=None):
    """Create a logger that writes to both a log file and the console.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated. Missing parent directories are created.
        verbosity: 0 selects DEBUG, 1 selects INFO, 2 selects WARNING.
        name: Name passed to ``logging.getLogger``; ``None`` selects the root logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: If ``verbosity`` is not 0, 1 or 2.

    Note:
        Every call attaches one new file handler and one new stream handler, so
        calling this twice with the same ``name`` duplicates each message.
    """
    # Imported here because this file has no top-level `import logging`; without
    # it the function only works if a star-import happens to provide the name.
    import logging

    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter("[%(asctime)s][%(filename)s][%(levelname)s] %(message)s")
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # FileHandler does not create directories; do it up front so a fresh
    # output directory does not abort the run with FileNotFoundError.
    log_dir = os.path.dirname(filename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    return logger


#
# ---- Command-line options ----
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='./')  # directory that contains the LMDB datasets
parser.add_argument('--pth', type=str, default='dtd.pth')
parser.add_argument('--lmdb_name', type=str, default='DocTamperV1-FCD')
parser.add_argument('--minq', type=int, default=75)
args = parser.parse_args()

# ---- Model: two-class segmentation network on GPU, wrapped for multi-GPU use ----
model = torch.nn.DataParallel(seg_dtd('', 2).cuda())

# ---- Training configuration consumed by train_net_qyl ----
param = {
    'model_name': 'textTamper',
    'epochs': 100,
    'batch_size': 8,
    'iter_inter': 5,
    'save_log_dir': './trainStore',
    'save_ckpt_dir': './trainStore',
    'save_epoch': 10,
    'T0': 10,
    'load_ckpt_dir': None,
}

# ---- Datasets: fixed training LMDB, validation LMDB chosen by --lmdb_name ----
train_data = TamperDataset(args.data_root + 'DocTamperV1-SCD', False, minq=args.minq)
valid_data = TamperDataset(args.data_root + args.lmdb_name, False, minq=args.minq)


def _should_save_snapshot(save_epoch, T0, epoch):
    """Return True when a periodic snapshot is due after 0-based ``epoch``."""
    if save_epoch is None:
        return False
    if isinstance(save_epoch, dict):
        # Mapping form: T0 -> collection of epochs at which to snapshot.
        return epoch in save_epoch.get(T0, ())
    if isinstance(save_epoch, int):
        # Integer form: snapshot after every `save_epoch` completed epochs.
        return save_epoch > 0 and (epoch + 1) % save_epoch == 0
    # Any other container is taken as an explicit collection of epoch indices.
    return epoch in save_epoch


def train_net_qyl(param, model, train_data, valid_data, plot=False, device='cuda'):
    """Train ``model`` with mixed precision and validate after every epoch.

    Args:
        param: Configuration dict with the keys
            ``model_name`` (str, used in the log file name),
            ``epochs`` (int, total number of epochs),
            ``batch_size`` (int),
            ``iter_inter`` (int, log every this many training batches),
            ``save_log_dir`` / ``save_ckpt_dir`` (str, output directories;
            created if missing),
            ``load_ckpt_dir`` (optional path of a checkpoint to resume from),
            ``save_epoch`` (int: snapshot every N epochs; dict: ``T0`` ->
            epochs to snapshot; other container: epochs to snapshot),
            ``T0`` (int, first cycle length of the cosine warm-restart schedule).
        model: Network called as ``model(image, dct_coef, qtable)``.
        train_data: Training dataset whose items are dicts with the keys
            ``'image'``, ``'label'``, ``'rgb'``, ``'q'`` and ``'i'``.
        valid_data: Validation dataset with the same item layout.
        plot: Currently unused.
        device: Device the batches are moved to.

    Returns:
        ``(best_model, model)``: a deep copy of the model taken at the epoch with
        the best class-1 IoU on the validation set, and the model after the
        final epoch.
    """
    # Unpack the configuration.
    model_name = param['model_name']
    epochs = param['epochs']
    batch_size = param['batch_size']
    iter_inter = param['iter_inter']
    save_log_dir = param['save_log_dir']
    save_ckpt_dir = param['save_ckpt_dir']
    load_ckpt_dir = param.get('load_ckpt_dir', None)
    save_epoch = param['save_epoch']
    T0 = param['T0']
    scaler = GradScaler()

    # Neither the log file handler nor torch.save creates directories.
    for directory in (save_log_dir, save_ckpt_dir):
        if directory:
            os.makedirs(directory, exist_ok=True)

    # Data loaders, optimizer, schedule and losses.
    train_data_size = len(train_data)
    valid_data_size = len(valid_data)
    _, y, x = train_data[0]['image'].shape
    train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=4)
    valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False, num_workers=4)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=T0, T_mult=2, eta_min=1e-5,
                                                                     last_epoch=-1)
    LovaszLoss_fn = LovaszLoss(mode='multiclass')
    SoftCrossEntropy_fn = SoftCrossEntropyLoss(smooth_factor=0.1)
    logger = get_logger(
        os.path.join(save_log_dir, time.strftime("%m-%d %H:%M:%S", time.localtime()) + '_' + model_name + '.log'))

    # Per-epoch history; collected but not returned (nothing in this function reads it).
    train_loss_total_epochs, valid_loss_total_epochs, epoch_lr = [], [], []
    train_loader_size = len(train_loader)
    best_iou = 0
    best_mode = copy.deepcopy(model)
    epoch_start = 0
    if load_ckpt_dir is not None:
        ckpt = torch.load(load_ckpt_dir)
        # NOTE(review): checkpoints store the epoch that just finished, so resuming
        # here repeats that epoch -- confirm whether `+ 1` is wanted.
        epoch_start = ckpt['epoch']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])

    logger.info(
        'Total Epoch:{} Image_size:({}, {}) Training num:{}  Validation num:{}'.format(epochs, x, y, train_data_size,
                                                                                       valid_data_size))

    for epoch in range(epoch_start, epochs):
        # Separate name for the timer so the resume index above is not overwritten.
        epoch_tic = time.time()

        # ---- Training phase ----
        model.train()
        train_epoch_loss = AverageMeter()
        train_iter_loss = AverageMeter()
        for batch_idx, batch_samples in enumerate(tqdm(train_loader)):
            data = batch_samples['image'].to(device)
            target = batch_samples['label'].to(device)
            dct_coef = batch_samples['rgb'].to(device)
            qs = batch_samples['q'].unsqueeze(1).to(device)
            # Only the forward pass and the loss belong under autocast; the backward
            # pass and the optimizer step run outside it.
            with autocast():
                pred = model(data, dct_coef, qs)
                loss = LovaszLoss_fn(pred, target) + SoftCrossEntropy_fn(pred, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # Fractional epoch so the cosine schedule advances within an epoch.
            scheduler.step(epoch + batch_idx / train_loader_size)
            image_loss = loss.item()
            train_epoch_loss.update(image_loss)
            train_iter_loss.update(image_loss)
            if batch_idx % iter_inter == 0:
                spend_time = time.time() - epoch_tic
                logger.info('[train] epoch:{} iter:{}/{} {:.2f}% lr:{:.6f} loss:{:.6f} ETA:{}min'.format(
                    epoch, batch_idx, train_loader_size, batch_idx / train_loader_size * 100,
                    optimizer.param_groups[-1]['lr'],
                    train_iter_loss.avg, spend_time / (batch_idx + 1) * train_loader_size // 60 - spend_time // 60))
                train_iter_loss.reset()

        # ---- Validation phase ----
        model.eval()
        valid_epoch_loss = AverageMeter()
        iou = IOUMetric(2)
        with torch.no_grad():
            for batch_samples in valid_loader:
                data = batch_samples['image'].to(device)
                target = batch_samples['label'].to(device)
                dct_coef = batch_samples['rgb'].to(device)
                qs = batch_samples['q'].unsqueeze(1).to(device)
                # Same three-input call as in training; the previous `model(data)`
                # did not match the way the model is invoked above.
                pred = model(data, dct_coef, qs)
                loss = LovaszLoss_fn(pred, target) + SoftCrossEntropy_fn(pred, target)
                pred = np.argmax(pred.cpu().data.numpy(), axis=1)
                iou.add_batch(pred, target.cpu().data.numpy())
                valid_epoch_loss.update(loss.item())
            acc, acc_cls, iu, mean_iu, fwavacc = iou.evaluate()
            logger.info('[val] epoch:{} iou:{}'.format(epoch, iu))

        # Record loss and learning rate for this epoch.
        train_loss_total_epochs.append(train_epoch_loss.avg)
        valid_loss_total_epochs.append(valid_epoch_loss.avg)
        epoch_lr.append(optimizer.param_groups[0]['lr'])

        # Periodic weight-only snapshot.
        if _should_save_snapshot(save_epoch, T0, epoch):
            torch.save(model.state_dict(), '{}/cosine_epoch{}.pth'.format(save_ckpt_dir, epoch))

        # Always refresh the resumable "latest" checkpoint.
        state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
        torch.save(state, os.path.join(save_ckpt_dir, 'checkpoint-latest.pth'))

        # Keep the checkpoint with the best class-1 IoU.
        if iu[1] > best_iou:
            torch.save(state, os.path.join(save_ckpt_dir, 'checkpoint-best.pth'))
            best_iou = iu[1]
            best_mode = copy.deepcopy(model)
            logger.info('[save] Best Model saved at epoch:{} ============================='.format(epoch))

    return best_mode, model



# Script entry point: start training with the module-level configuration, model and datasets.
train_net_qyl(param, model, train_data=train_data, valid_data=valid_data)
